{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MF7BncmmLBeO"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as tt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized; it serves an educational purpose. It is written for CPU, uses only fully-connected networks, and relies on an extremely simplistic dataset. However, it contains all the components needed to understand how flow matching works, and it should be rather easy to extend it to more sophisticated models. The code can be run on almost any laptop/PC, and it takes a couple of minutes at most to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RKsmjLumL5A2"
   },
   "source": [
    "## Dataset: Digits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scikit-learn dataset called Digits. It consists of 1,797 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without a GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hSWUnXAYLLif"
   },
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Scikit-Learn Digits dataset.\"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        digits = load_digits()\n",
    "        if mode == 'train':\n",
    "            self.data = digits.data[:1000].astype(np.float32)\n",
    "        elif mode == 'val':\n",
    "            self.data = digits.data[1000:1350].astype(np.float32)\n",
    "        else:\n",
    "            self.data = digits.data[1350:].astype(np.float32)\n",
    "\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.data[idx]\n",
    "        if self.transforms:\n",
    "            sample = self.transforms(sample)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qSP2qiMqMICK"
   },
   "source": [
    "## Flow Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FlowMatching(nn.Module):\n",
    "    def __init__(self, vnet, sigma, D, T, stochastic_euler=False, prob_path=\"icfm\"):\n",
    "        super(FlowMatching, self).__init__()\n",
    "\n",
    "        print('Flow Matching by JT.')\n",
    "\n",
    "        self.vnet = vnet\n",
    "\n",
    "        self.time_embedding = nn.Sequential(nn.Linear(1, D), nn.Tanh())\n",
    "        \n",
    "        # other params\n",
    "        self.D = D\n",
    "        \n",
    "        self.T = T\n",
    "\n",
    "        self.sigma = sigma\n",
    "        \n",
    "        self.stochastic_euler = stochastic_euler\n",
    "        \n",
    "        assert prob_path in [\"icfm\", \"fm\"], f\"Error: The probability path could be either Independent CFM (icfm) or Lipman's Flow Matching (fm) but {prob_path} was provided.\"\n",
    "        self.prob_path = prob_path\n",
    "        \n",
    "        self.PI = torch.from_numpy(np.asarray(np.pi))\n",
    "    \n",
    "    def log_p_base(self, x, reduction='sum', dim=1):\n",
    "        log_p = -0.5 * torch.log(2. * self.PI) - 0.5 * x**2.\n",
    "        if reduction == 'mean':\n",
    "            return torch.mean(log_p, dim)\n",
    "        elif reduction == 'sum':\n",
    "            return torch.sum(log_p, dim)\n",
    "        else:\n",
    "            return log_p\n",
    "    \n",
    "    def sample_base(self, x_1):\n",
    "        # Gaussian base distribution\n",
    "        if self.prob_path == \"icfm\":\n",
    "            return torch.randn_like(x_1)\n",
    "        elif self.prob_path == \"fm\":\n",
    "            return torch.randn_like(x_1)\n",
    "        else:\n",
    "            return None\n",
    "    \n",
    "    def sample_p_t(self, x_0, x_1, t):\n",
    "        if self.prob_path == \"icfm\":\n",
    "            mu_t = (1. - t) * x_0 + t * x_1\n",
    "            sigma_t = self.sigma\n",
    "        elif self.prob_path == \"fm\":\n",
    "            mu_t = t * x_1\n",
    "            sigma_t = t * self.sigma - t + 1.\n",
    "        \n",
    "        x = mu_t + sigma_t * torch.randn_like(x_1)\n",
    "        \n",
    "        return x\n",
    "    \n",
    "    def conditional_vector_field(self, x, x_0, x_1, t):\n",
    "        if self.prob_path == \"icfm\":\n",
    "            u_t = x_1 - x_0\n",
    "        elif self.prob_path == \"fm\":\n",
    "            u_t = (x_1 - (1. - self.sigma) * x) / (1. - (1. - self.sigma) * t)\n",
    "        \n",
    "        return u_t\n",
    "\n",
    "    def forward(self, x_1, reduction='mean'):\n",
    "        # =====Flow Matching\n",
    "        # =====\n",
    "        # z ~ q(z), e.g., q(z) = q(x_0) q(x_1), q(x_0) = base, q(x_1) = empirical\n",
    "        # t ~ Uniform(0, 1)\n",
    "        x_0 = self.sample_base(x_1)  # sample from the base distribution (e.g., Normal(0,I))\n",
    "        t = torch.rand(size=(x_1.shape[0], 1))\n",
    "        \n",
    "        # =====\n",
    "        # sample from p(x|z)\n",
    "        x = self.sample_p_t(x_0, x_1, t)  # sample independent rv \n",
    "\n",
    "        # =====\n",
    "        # invert interpolation, i.e., calculate vector field v(x,t)\n",
    "        t_embd = self.time_embedding(t)\n",
    "        v = self.vnet(x + t_embd)\n",
    "        \n",
    "        # =====\n",
    "        # conditional vector field\n",
    "        u_t = self.conditional_vector_field(x, x_0, x_1, t)\n",
    "\n",
    "        # =====LOSS: Flow Matching\n",
    "        FM_loss = torch.pow(v - u_t, 2).mean(-1)\n",
    "        \n",
    "        # Final LOSS\n",
    "        if reduction == 'sum':\n",
    "            loss = FM_loss.sum()\n",
    "        else:\n",
    "            loss = FM_loss.mean()\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def sample(self,  batch_size=64):\n",
    "        # Euler method\n",
    "        # sample x_0 first\n",
    "        x_t = self.sample_base(torch.empty(batch_size, self.D))\n",
    "        \n",
    "        # then go step-by-step to x_1 (data)        \n",
    "        ts = torch.linspace(0., 1., self.T)\n",
    "        delta_t = ts[1] - ts[0]\n",
    "        \n",
    "        for t in ts[1:]:\n",
    "            t_embedding = self.time_embedding(torch.Tensor([t]))\n",
    "            x_t = x_t + self.vnet(x_t + t_embedding) * delta_t\n",
    "            # Stochastic Euler method\n",
    "            if self.stochastic_euler:\n",
    "                x_t = x_t + torch.randn_like(x_t) * delta_t\n",
    "        \n",
    "        x_final = torch.tanh(x_t)\n",
    "        return x_final\n",
    "    \n",
    "    def log_prob(self, x_1, reduction='mean'):\n",
    "        # backward Euler (see Appendix C in Lipman's paper)\n",
    "        ts = torch.linspace(1., 0., self.T)\n",
    "        delta_t = ts[1] - ts[0]\n",
    "        \n",
    "        for t in ts:\n",
    "            if t == 1.:\n",
    "                x_t = x_1 * 1.\n",
    "                f_t = 0.\n",
    "            else:\n",
    "                # Calculate phi_t\n",
    "                t_embedding = self.time_embedding(torch.Tensor([t]))\n",
    "                x_t =x_t - self.vnet(x_t + t_embedding) * delta_t\n",
    "                \n",
    "                # Calculate f_t\n",
    "                # approximate the divergence using the Hutchinson trace estimator and the autograd\n",
    "                self.vnet.eval()  # set the vector field net to evaluation\n",
    "                \n",
    "                x = torch.FloatTensor(x_t.data)  # copy the original data (it doesn't require grads!)\n",
    "                x.requires_grad = True \n",
    "                \n",
    "                e = torch.randn_like(x)  # epsilon ~ Normal(0, I) \n",
    "                \n",
    "                e_grad = torch.autograd.grad(self.vnet(x).sum(), x, create_graph=True)[0]\n",
    "                e_grad_e = e_grad * e\n",
    "                f_t = e_grad_e.view(x.shape[0], -1).sum(dim=1)\n",
    "\n",
    "                self.vnet.eval()  # set the vector field net to train again\n",
    "        \n",
    "        log_p_1 = self.log_p_base(x_t, reduction='sum') - f_t\n",
    "        \n",
    "        if reduction == \"mean\":\n",
    "            return log_p_1.mean()\n",
    "        elif reduction == \"sum\":\n",
    "            return log_p_1.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vUoPkTmrMVnx"
   },
   "source": [
    "## Evaluation and Training functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JvwmRoi7MVto"
   },
   "source": [
    "**Evaluation step, sampling and curve plotting**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JHx4RIqDLZe9"
   },
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    N = 0.\n",
    "    for indx_batch, test_batch in enumerate(test_loader):\n",
    "        loss_t = -model_best.log_prob(test_batch, reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: nll={loss}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}')\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "def samples_real(name, test_loader):\n",
    "    # REAL-------\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = next(iter(test_loader)).detach().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(ax.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.savefig(name+'_real_images.pdf', bbox_inches='tight')\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def samples_generated(name, data_loader, extra_name='', T=None):\n",
    "    # GENERATIONS-------\n",
    "    model_best = torch.load(name + '.model')\n",
    "    model_best.eval()\n",
    "    \n",
    "    if T is not None:\n",
    "        model_best.T = T\n",
    "\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = model_best.sample(batch_size=num_x * num_y)\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(ax.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "    plt.close()\n",
    "\n",
    "def plot_curve(name, nll_val):\n",
    "    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel('nll')\n",
    "    plt.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "umU3VYKzMbDt"
   },
   "source": [
    "**Training step**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NxkUZ1xVLbm_"
   },
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):\n",
    "    \"\"\"Train `model` with early stopping on the validation nll.\n",
    "\n",
    "    After every epoch the model is evaluated on `val_loader`; whenever the\n",
    "    validation nll improves (and always after the first epoch) the whole model\n",
    "    is saved to `name + '.model'`. Samples of the saved model are written\n",
    "    after every epoch.\n",
    "\n",
    "    Args:\n",
    "        name: prefix of the checkpoint and of the sample files.\n",
    "        max_patience: training stops once the number of consecutive epochs\n",
    "            without improvement exceeds this value.\n",
    "        num_epochs: maximum number of epochs.\n",
    "        model: module whose call returns the training loss for a batch.\n",
    "        optimizer: optimizer updating the parameters of `model`.\n",
    "        training_loader: iterable of training batches.\n",
    "        val_loader: iterable of validation batches.\n",
    "\n",
    "    Returns:\n",
    "        Array with the validation nll of every completed epoch.\n",
    "    \"\"\"\n",
    "    nll_val = []\n",
    "    best_nll = float('inf')\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for batch in training_loader:\n",
    "            # call the module (not .forward) so that registered hooks run\n",
    "            loss = model(batch)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            # every batch builds a fresh graph, so it does not need to be retained\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        nll_val.append(loss_val)  # save for plotting\n",
    "\n",
    "        # the first epoch always creates the checkpoint that is loaded below\n",
    "        if e == 0 or loss_val < best_nll:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_val\n",
    "            patience = 0\n",
    "        else:\n",
    "            patience = patience + 1\n",
    "\n",
    "        # samples come from the checkpoint on disk, i.e. the best model so far\n",
    "        samples_generated(name, val_loader, extra_name=\"_epoch_\" + str(e))\n",
    "\n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    return np.asarray(nll_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0BXJ9dN0MinB"
   },
   "source": [
    "## Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KsF7f-Q-MkWu"
   },
   "source": [
    "**Initialize datasets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _rescale_pixels(x):\n",
    "    # pixel value 0 maps to -1 and 17 would map to 1; the largest pixel value\n",
    "    # mentioned above (16) therefore lands at 2 * 16 / 17 - 1, slightly below 1\n",
    "    return 2. * (x / 17.) - 1.\n",
    "\n",
    "\n",
    "transforms = tt.Lambda(_rescale_pixels)  # changing to [-1, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fqZKMNM0LdQ1"
   },
   "outputs": [],
   "source": [
    "# one Digits split per stage, all sharing the same pixel transform\n",
    "train_data, val_data, test_data = [\n",
    "    Digits(mode=split, transforms=transforms) for split in ('train', 'val', 'test')\n",
    "]\n",
    "\n",
    "# only the training batches are shuffled; validation and test keep their order\n",
    "loader_kwargs = dict(batch_size=32)\n",
    "training_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)\n",
    "val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)\n",
    "test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6lEKUznpMns7"
   },
   "source": [
    "**Hyperparameters**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ANQo7LrGLjIN"
   },
   "outputs": [],
   "source": [
    "# probability path used by FlowMatching: 'icfm' or 'fm'\n",
    "prob_path = \"fm\"\n",
    "\n",
    "D = 64   # input dimension\n",
    "\n",
    "M = 512  # the number of neurons in each hidden layer of the vector-field network\n",
    "\n",
    "# std parameter of the probability path (passed to FlowMatching as sigma)\n",
    "sigma = 0.1\n",
    "\n",
    "# number of time-grid points used by the Euler solvers in FlowMatching\n",
    "T = 100\n",
    "\n",
    "lr = 1e-3 # learning rate\n",
    "num_epochs = 1000 # max. number of epochs\n",
    "max_patience = 20 # early stopping: training stops if the validation nll doesn't improve for longer than 20 epochs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-7APXeunMrDh"
   },
   "source": [
    "**Creating a folder for results**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bjSUn1eWLkWm"
   },
   "outputs": [],
   "source": [
    "# every run writes into its own folder named after the probability path and T\n",
    "name = prob_path + '_' + str(T)\n",
    "result_dir = 'results/' + name + '/'\n",
    "# makedirs also creates the parent 'results' folder (os.mkdir fails when it is\n",
    "# missing) and exist_ok makes the cell safe to re-run\n",
    "os.makedirs(result_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hpwm6LWUMulQ"
   },
   "source": [
    "**Initializing the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FrnNsCqQLmK3",
    "outputId": "5f0cf2b1-0a96-4f5c-da9e-f78f909a5259"
   },
   "outputs": [],
   "source": [
    "# Vector-field network: three SELU hidden layers of width M, output clipped to [-3, 3]\n",
    "vector_field_layers = [\n",
    "    nn.Linear(D, M), nn.SELU(),\n",
    "    nn.Linear(M, M), nn.SELU(),\n",
    "    nn.Linear(M, M), nn.SELU(),\n",
    "    nn.Linear(M, D), nn.Hardtanh(min_val=-3., max_val=3.),\n",
    "]\n",
    "nnet = nn.Sequential(*vector_field_layers)\n",
    "\n",
    "# Wrap the network in the full flow matching model\n",
    "model = FlowMatching(vnet=nnet, sigma=sigma, D=D, T=T, stochastic_euler=False, prob_path=prob_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3SzTemY3NSxO"
   },
   "source": [
    "**Optimizer - here we use Adamax**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R9TZtLVtLoWc"
   },
   "outputs": [],
   "source": [
    "# OPTIMIZER: Adamax over every trainable parameter of the model\n",
    "trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.Adamax(trainable_params, lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dNf__W_ONVHA"
   },
   "source": [
    "**Training loop**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KhqHgluGLqIC",
    "outputId": "c52fa1e4-3376-4bff-9f87-6f03613c4e42"
   },
   "outputs": [],
   "source": [
    "# Training procedure: checkpoints and per-epoch samples are written under result_dir\n",
    "checkpoint_prefix = result_dir + name\n",
    "nll_val = training(\n",
    "    name=checkpoint_prefix,\n",
    "    max_patience=max_patience,\n",
    "    num_epochs=num_epochs,\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    training_loader=training_loader,\n",
    "    val_loader=val_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-3XTxgEcNXfp"
   },
   "source": [
    "**The final evaluation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "okK1mV_-LrRU",
    "outputId": "4664693f-742d-4453-94cf-d051d2efa9be"
   },
   "outputs": [],
   "source": [
    "# evaluate the best checkpoint on the test set and store the resulting nll\n",
    "test_loss = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "# the context manager closes the file even if writing fails\n",
    "with open(result_dir + name + '_test_loss.txt', \"w\") as f:\n",
    "    f.write(str(test_loss))\n",
    "\n",
    "# real vs. generated images and the validation curve of the training run\n",
    "samples_real(result_dir + name, test_loader)\n",
    "samples_generated(result_dir + name, test_loader, extra_name='FINAL')\n",
    "\n",
    "plot_curve(result_dir + name, nll_val)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "vae_priors.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
